{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# GPU Operator Customization with CuPy\n",
    "\n",
    "[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/brainpy/brainpy/blob/master/docs_version2/tutorial_advanced/operator_custom_with_cupy.ipynb)\n",
    "[![Open in Kaggle](https://kaggle.com/static/images/open-in-kaggle.svg)](https://kaggle.com/kernels/welcome?src=https://github.com/brainpy/brainpy/blob/master/docs_version2/tutorial_advanced/operator_custom_with_cupy.ipynb)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This functionality is only available for ``brainpylib>=0.3.1``. "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## English Version"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Although we can now use the flexible taichi custom operator approach, taichi on cuda does not have more fine-grained control or optimization for some scenarios. So for such scenarios, we can use cupy's \n",
    "- [`RawModule`](https://docs.cupy.dev/en/stable/user_guide/kernel.html#raw-kernels)\n",
    "- [`jit.rawkernel`](https://docs.cupy.dev/en/stable/user_guide/kernel.html#jit-kernel-definition) \n",
    "\n",
    "to compile and run CUDA native code directly as strings or cupy JIT function in real time for finer grained control.\n",
    "\n",
    "Start by importing the relevant Python package."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import brainpy.math as bm\n",
    "\n",
    "import jax\n",
    "import cupy as cp\n",
    "from cupyx import jit\n",
    "\n",
    "bm.set_platform('gpu')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CuPy RawModule\n",
    "\n",
    "For dealing a large raw CUDA source or loading an existing CUDA binary, the RawModule class can be more handy. It can be initialized either by a CUDA source code. The needed kernels can then be retrieved by calling the get_function() method, which returns a RawKernel instance that can be invoked as discussed above.\n",
    "\n",
    "Be aware that the order of parameters in the kernel function you want to call should **keep outputs at the end of the parameter list**."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "source_code = r'''\n",
    "    extern \"C\"{\n",
    "\n",
    "    __global__ void kernel(const float* x1, const float* x2, unsigned int N, float* y)\n",
    "    {\n",
    "        unsigned int tid = blockDim.x * blockIdx.x + threadIdx.x;\n",
    "        if (tid < N)\n",
    "        {\n",
    "            y[tid] = x1[tid] + x2[tid];\n",
    "        }\n",
    "    }\n",
    "    }\n",
    "'''\n",
    "mod = cp.RawModule(code=source_code)\n",
    "kernel = mod.get_function('kernel')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After define the `RawModule` and get the kernel function. You can use `bm.XLACustomOp` to register it into it's `gpu_kernel` and call it with the appropriate `gird` and `block` you want (**Here these two parameters both should be Tuple**).\n",
    "\n",
    "Specify the outs parameter when calling, using jax.ShapeDtypeStruct to define the shape and data type of the output."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# prepare inputs\n",
    "N = 10\n",
    "x1 = bm.ones((N, N))\n",
    "x2 = bm.ones((N, N))\n",
    "\n",
    "# register the kernel as a custom op\n",
    "prim1 = bm.XLACustomOp(gpu_kernel=kernel)\n",
    "\n",
    "# call the custom op\n",
    "y = prim1(x1, x2, N**2, grid=(N,), block=(N,), outs=[jax.ShapeDtypeStruct((N, N), dtype=bm.float32)])[0]"
   ]
  },
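  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The call above uses `grid=(N,)` and `block=(N,)`, which gives exactly `N**2` threads. For sizes that do not factor so neatly, a common CUDA convention is to fix the block size and round the number of blocks up, relying on the `if (tid < N)` guard in the kernel to skip surplus threads. The sketch below is plain Python and assumes a 1D launch; the helper name `launch_config` and the default of 256 threads per block are illustrative choices, not part of BrainPy or CuPy.\n",
    "\n",
    "```python\n",
    "def launch_config(n_elements, threads_per_block=256):\n",
    "    # Hypothetical helper: return (grid, block) tuples for a 1D launch\n",
    "    # with at least one thread per element.\n",
    "    if n_elements <= 0:\n",
    "        raise ValueError('n_elements must be positive')\n",
    "    # Ceiling division: enough blocks so that grid * block >= n_elements.\n",
    "    n_blocks = (n_elements + threads_per_block - 1) // threads_per_block\n",
    "    return (n_blocks,), (threads_per_block,)\n",
    "\n",
    "\n",
    "grid, block = launch_config(1000)\n",
    "print(grid, block)  # (4,) (256,)\n",
    "```\n",
    "\n",
    "The returned tuples could then be passed as the `grid` and `block` arguments when calling the custom op."
   ]
  },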
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CuPy JIT RawKernel\n",
    "The cupyx.jit.rawkernel decorator can create raw CUDA kernels from Python functions.\n",
    "\n",
    "In this section, a Python function wrapped with the decorator is called a target function.\n",
    "\n",
    "Here is a short example for how to write a cupyx.jit.rawkernel to copy the values from x to y using a grid-stride loop:\n",
    "\n",
    "Launching a CUDA kernel on a GPU with pre-determined grid/block sizes requires basic understanding in the CUDA Programming Model. And the compilation will be deferred until the first function call. CuPy’s JIT compiler infers the types of arguments at the call time, and will cache the compiled kernels for speeding up any subsequent calls."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@jit.rawkernel()\n",
    "def elementwise_copy(x, size, y):\n",
    "    tid = jit.blockIdx.x * jit.blockDim.x + jit.threadIdx.x\n",
    "    ntid = jit.gridDim.x * jit.blockDim.x\n",
    "    for i in range(tid, size, ntid):\n",
    "        y[i] = x[i]"
   ]
  },
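  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To see why a grid-stride loop covers every element even when there are fewer threads than elements, the indexing can be emulated on the CPU in plain Python. This is only an illustrative sketch: `emulate_grid_stride` is a hypothetical helper, and its sequential loops stand in for threads that run in parallel on the GPU.\n",
    "\n",
    "```python\n",
    "def emulate_grid_stride(size, grid_dim, block_dim):\n",
    "    # Map each global thread id to the list of indices it would process.\n",
    "    ntid = grid_dim * block_dim  # total number of threads in the launch\n",
    "    work = {}\n",
    "    for block_idx in range(grid_dim):\n",
    "        for thread_idx in range(block_dim):\n",
    "            tid = block_idx * block_dim + thread_idx\n",
    "            work[tid] = list(range(tid, size, ntid))\n",
    "    return work\n",
    "\n",
    "\n",
    "# 4 threads in total (2 blocks of 2 threads) handling 10 elements\n",
    "work = emulate_grid_stride(size=10, grid_dim=2, block_dim=2)\n",
    "print(work[0])  # [0, 4, 8]\n",
    "print(work[3])  # [3, 7]\n",
    "\n",
    "# every index is visited exactly once\n",
    "covered = sorted(i for idxs in work.values() for i in idxs)\n",
    "print(covered == list(range(10)))  # True\n",
    "```\n",
    "\n",
    "Because each thread advances by the total thread count, the same kernel stays correct for any `grid` and `block` choice, which only affects how the work is divided."
   ]
  },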
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "After define the `jit.rawkernel`. You can use `bm.XLACustomOp` to register it into it's `gpu_kernel` and call it with the appropriate `gird` and `block` you want (**Here these two parameters both should be Tuple**)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# prepare inputs\n",
    "size = 100\n",
    "x = bm.ones((size,))\n",
    "\n",
    "# register the kernel as a custom op\n",
    "prim2 = bm.XLACustomOp(gpu_kernel=elementwise_copy)\n",
    "\n",
    "# call the custom op\n",
    "y = prim2(x, size, grid=(10,), block=(10,), outs=[jax.ShapeDtypeStruct((size,), dtype=bm.float32)])[0]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 中文版\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "尽管我们现在可以使用灵活的taichi自定义操作符方法，但在cuda后端上，taichi没有更细粒度的控制或某些场景下的优化。因此，对于这类场景，我们可以使用cupy的\n",
    "- [`RawModule`](https://docs.cupy.dev/en/stable/user_guide/kernel.html#raw-kernels)\n",
    "- [`jit.rawkernel`](https://docs.cupy.dev/en/stable/user_guide/kernel.html#jit-kernel-definition) \n",
    "\n",
    "来直接作为字符串或cupy JIT函数实时编译并运行CUDA原生代码，以实现更细致的控制。\n",
    "\n",
    "首先，导入相关的Python包。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import brainpy.math as bm\n",
    "\n",
    "import jax\n",
    "import cupy as cp\n",
    "from cupyx import jit\n",
    "\n",
    "bm.set_platform('gpu')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CuPy RawModule\n",
    "`RawModule`类可以通过传入CUDA源码的字符串来初始化，然后，通过调用`get_function()`方法可以检索所需的kernel，该方法返回一个可以调用的RawKernel实例。\n",
    "\n",
    "请注意，您想要调用的kernel中的参数顺序应该**将输出参数放在参数列表的末尾**。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "source_code = '''\n",
    "    extern \"C\"{\n",
    "\n",
    "    __global__ void kernel(const float* x1, const float* x2, unsigned int N, float* y)\n",
    "    {\n",
    "        unsigned int tid = blockDim.x * blockIdx.x + threadIdx.x;\n",
    "        if (tid < N)\n",
    "        {\n",
    "            y[tid] = x1[tid] + x2[tid];\n",
    "        }\n",
    "    }\n",
    "    }\n",
    "'''\n",
    "mod = cp.RawModule(code=source_code)\n",
    "kernel = mod.get_function('kernel')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "定义了RawModule并获取了内核函数后，可以使用`bm.XLACustomOp`将其注册到其`gpu_kernel`中，并使用您想要的适当的`grid`和`block`调用它（**这里这两个参数都应该是元组**）。\n",
    "\n",
    "最后在调用中指定`outs`参数，用`jax.ShapeDtypeStruct`来指定输出的形状和数据类型。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 准备输入\n",
    "N = 10\n",
    "x1 = bm.ones((N, N))\n",
    "x2 = bm.ones((N, N))\n",
    "\n",
    "# 将kernel注册为自定义算子\n",
    "prim1 = bm.XLACustomOp(gpu_kernel=kernel)\n",
    "\n",
    "# 调用自定义算子\n",
    "y = prim1(x1, x2, N**2, grid=(N,), block=(N,), outs=[jax.ShapeDtypeStruct((N, N), dtype=bm.float32)])[0]\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### CuPy JIT RawKernel\n",
    "\n",
    "`cupyx.jit.rawkernel`装饰器可以从Python函数创建原生CUDA内核。\n",
    "\n",
    "以下是一个如何通过`cupyx.jit.rawkernel`来使用grid-stride循环从`x`复制值到`y`的简短示例：\n",
    "\n",
    "在GPU上启动CUDA内核，需要预先确定的grid/block大小，这需要对CUDA编程模型有基本的了解。编译将延迟到第一次函数调用时。CuPy的JIT编译器会在调用时推断参数的类型，并将缓存编译后的内核以加速任何后续调用。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@jit.rawkernel()\n",
    "def elementwise_copy(x, size, y):\n",
    "    tid = jit.blockIdx.x * jit.blockDim.x + jit.threadIdx.x\n",
    "    ntid = jit.gridDim.x * jit.blockDim.x\n",
    "    for i in range(tid, size, ntid):\n",
    "        y[i] = x[i]\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "定义了`jit.rawkernel`后，您可以使用`bm.XLACustomOp`将其注册到其`gpu_kernel`中，并使用您想要的适当的`grid`和`block`调用它（**这里这两个参数都应该是元组**）。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 准备输入\n",
    "size = 100\n",
    "x = bm.ones((size,))\n",
    "\n",
    "# 将kernel注册为自定义算子\n",
    "prim2 = bm.XLACustomOp(gpu_kernel=elementwise_copy)\n",
    "\n",
    "# 调用自定义算子\n",
    "y = prim2(x, size, grid=(10,), block=(10,), outs=[jax.ShapeDtypeStruct((size,), dtype=bm.float32)])[0]\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
